The diamonds dataset that ships with the ggplot2 package contains about fifty-four thousand observations of diamond pieces. Each observation records ten attributes such as price (in USD), carat (weight), quality of the cut, and so on. I could not find what year it was collected.
# 1. load and quick look
# diamonds ships with ggplot2: ~54k rows, 10 columns (price in USD, carat,
# cut, color, clarity, depth, table, and the x/y/z dimensions)
data("diamonds")
View(diamonds)
?diamonds
summary(diamonds)
## carat cut color clarity depth
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065 Min. :43.00
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258 1st Qu.:61.00
## Median :0.7000 Very Good:12082 F: 9542 SI2 : 9194 Median :61.80
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171 Mean :61.75
## 3rd Qu.:1.0400 Ideal :21551 H: 8304 VVS2 : 5066 3rd Qu.:62.50
## Max. :5.0100 I: 5422 VVS1 : 3655 Max. :79.00
## J: 2808 (Other): 2531
## table price x y
## Min. :43.00 Min. : 326 Min. : 0.000 Min. : 0.000
## 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710 1st Qu.: 4.720
## Median :57.00 Median : 2401 Median : 5.700 Median : 5.710
## Mean :57.46 Mean : 3933 Mean : 5.731 Mean : 5.735
## 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540 3rd Qu.: 6.540
## Max. :95.00 Max. :18823 Max. :10.740 Max. :58.900
##
## z
## Min. : 0.000
## 1st Qu.: 2.910
## Median : 3.530
## Mean : 3.539
## 3rd Qu.: 4.040
## Max. :31.800
##
sum(is.na(diamonds)) # no missing values
## [1] 0
hist(diamonds$carat) # right skewed with a few outliers
hist(diamonds$price) # right skewed
hist(diamonds$depth) # mostly normal
# it makes sense for measures suggesting quality (price, weight) to be right skewed => the better the diamond the rarer (and vice versa)
# NOTE(review): the summary shows min x/y/z of 0.000 -- physically impossible
# dimensions, so "no NAs" does not mean no data issues; worth checking
# you can tell the data was well collected or pre cleaned
Two easy and intuitive ways to explore relationships between variables are graphs and summary tables.
We find a clear positive relationship between the weight and the price of a diamond (no surprise there).
The effect of the quality of the cut only becomes visible once we control for the weight.
# 2. Any Ideas?
plot(diamonds$depth, diamonds$price) # no clear relationship
plot(diamonds$table, diamonds$price) # same
# per-cut mean/min/max of price and carat
# NOTE(review): summarise_all() is superseded in current dplyr; the modern
# equivalent is summarise(across(everything(), list(mean = mean, min = min,
# max = max))), though that orders the output columns differently than the
# table printed below
dia_bycut <- diamonds %>%
group_by(cut) %>%
select(price, carat) %>%
summarise_all(list(mean = mean, min = min, max =max))
kable(dia_bycut)
| cut | price_mean | carat_mean | price_min | carat_min | price_max | carat_max |
|---|---|---|---|---|---|---|
| Fair | 4358.758 | 1.0461366 | 337 | 0.22 | 18574 | 5.01 |
| Good | 3928.864 | 0.8491847 | 327 | 0.23 | 18788 | 3.01 |
| Very Good | 3981.760 | 0.8063814 | 336 | 0.20 | 18818 | 4.00 |
| Premium | 4584.258 | 0.8919549 | 326 | 0.20 | 18823 | 4.01 |
| Ideal | 3457.542 | 0.7028370 | 326 | 0.20 | 18806 | 3.50 |
# depth distribution, one facet panel per cut quality
histdepth_bycut <- ggplot(diamonds, aes(x = depth)) +
  geom_histogram(binwidth = 0.5) +
  facet_wrap(~ cut) +
  labs(x = "Depth", y = "Count")
histdepth_bycut
# price distribution, bars filled (stacked) by cut quality
histprice_bycut <- ggplot(diamonds, aes(x = price, fill = cut)) +
  geom_histogram(binwidth = 100) +
  labs(x = "Price", y = "Count")
histprice_bycut
# price vs carat scatter, points coloured by cut quality
scatter_price <- diamonds %>%
  ggplot(aes(x = carat, y = price, color = cut)) +
  geom_point(alpha = 0.7) +
  xlab("Carat") + ylab("Price") +
  # fix: the mapped aesthetic is `color`, so the guide must target color;
  # the original `guides(fill = ...)` addressed a fill scale this plot does
  # not have, leaving the legend titled "cut" instead of "Cut"
  guides(color = guide_legend(title = "Cut"))
scatter_price
# baseline OLS: price on every other column
# (cut/color/clarity are ordered factors, so lm uses polynomial contrasts --
# hence the .L/.Q/.C/^4... coefficient names in the output below)
linear_price <- lm(price ~ ., data = diamonds)
summary(linear_price)
##
## Call:
## lm(formula = price ~ ., data = diamonds)
##
## Residuals:
## Min 1Q Median 3Q Max
## -21376.0 -592.4 -183.5 376.4 10694.2
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 5753.762 396.630 14.507 < 2e-16 ***
## carat 11256.978 48.628 231.494 < 2e-16 ***
## cut.L 584.457 22.478 26.001 < 2e-16 ***
## cut.Q -301.908 17.994 -16.778 < 2e-16 ***
## cut.C 148.035 15.483 9.561 < 2e-16 ***
## cut^4 -20.794 12.377 -1.680 0.09294 .
## color.L -1952.160 17.342 -112.570 < 2e-16 ***
## color.Q -672.054 15.777 -42.597 < 2e-16 ***
## color.C -165.283 14.725 -11.225 < 2e-16 ***
## color^4 38.195 13.527 2.824 0.00475 **
## color^5 -95.793 12.776 -7.498 6.59e-14 ***
## color^6 -48.466 11.614 -4.173 3.01e-05 ***
## clarity.L 4097.431 30.259 135.414 < 2e-16 ***
## clarity.Q -1925.004 28.227 -68.197 < 2e-16 ***
## clarity.C 982.205 24.152 40.668 < 2e-16 ***
## clarity^4 -364.918 19.285 -18.922 < 2e-16 ***
## clarity^5 233.563 15.752 14.828 < 2e-16 ***
## clarity^6 6.883 13.715 0.502 0.61575
## clarity^7 90.640 12.103 7.489 7.06e-14 ***
## depth -63.806 4.535 -14.071 < 2e-16 ***
## table -26.474 2.912 -9.092 < 2e-16 ***
## x -1008.261 32.898 -30.648 < 2e-16 ***
## y 9.609 19.333 0.497 0.61918
## z -50.119 33.486 -1.497 0.13448
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1130 on 53916 degrees of freedom
## Multiple R-squared: 0.9198, Adjusted R-squared: 0.9198
## F-statistic: 2.688e+04 on 23 and 53916 DF, p-value: < 2.2e-16
set.seed (11)
# assign each row a TRUE/FALSE train label, drawn independently with
# replacement from c(TRUE, FALSE) -- a roughly 50/50 split of the 53940 rows,
# not a resample of the data itself
diamonds$train <- sample(c(TRUE , FALSE), nrow(diamonds), replace = TRUE)
dia_train <- filter(diamonds, train)
dia_test <- filter(diamonds, !train)
lm1 <- lm(price ~ carat + cut, data = dia_train)
lm2 <- lm(price ~ carat + cut + depth + table, data = dia_train)
# NOTE(review): `.` in lm3/lm4 also picks up the new `train` column, which is
# constant TRUE inside dia_train; this is the likely cause of the
# rank-deficient-fit warnings further below -- consider `. - train`
lm3 <- lm(price ~ . - x - y -z, data = dia_train)
lm4 <- lm(price ~ . , data = dia_train)
# Out of sample mean squared error
lm1hat <- predict(lm1, newdata = dia_test)
mean( (dia_test$price - lm1hat)^2 )
## [1] 2295292
# Out-of-sample MSE helper: scores `model` on the global dia_test data frame.
mse <- function(model) {
  predictions <- predict(model, newdata = dia_test)
  mean((dia_test$price - predictions)^2)
}
mse(lm1)
## [1] 2295292
mse(lm2)
## [1] 2285418
mse(lm3)
## [1] 1337119
mse(lm4)
## [1] 1402004
# apply our function to multiple models at once
lapply(list(lm1, lm2, lm3, lm4), mse)
## [[1]]
## [1] 2295292
##
## [[2]]
## [1] 2285418
##
## [[3]]
## [1] 1337119
##
## [[4]]
## [1] 1402004
# we can save it into a table -- one named column per model
mse_table <- as.data.frame(lapply(list(lm1, lm2, lm3, lm4), mse),
col.names = c("lm1", "lm2", "lm3", "lm4"))
library(glmnet)
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.1-3
# NOTE(review): x1 is built but never used below; x2 is the design matrix
# actually fed to glmnet (it also carries the trainTRUE dummy, constant in
# dia_train, which shows up as an exact-zero coefficient in the output)
x1 <- model.matrix(~ ., select(dia_train, -price))
x2 <- model.matrix(price ~ ., dia_train)
y <- dia_train$price
# 61 lambda values from 10^3 = 1000 down to 10^-3 = 0.001
lambda_grid <- 10^seq(3, -3, by = -.1)
# it's important to understand what this grid does. we are varying lambda
# from 0.001 to 1000. the reason we sequence inside the power is because
# it makes the grid more sensitive where we ideally want it to be and less
# sensitive otherwise: it varies by small increments near 0.001 and by
# large increments near 1000
ridge <- glmnet(x2, y, alpha = 0, lambda = lambda_grid)
summary(ridge)
## Length Class Mode
## a0 61 -none- numeric
## beta 1525 dgCMatrix S4
## df 61 -none- numeric
## dim 2 -none- numeric
## lambda 61 -none- numeric
## dev.ratio 61 -none- numeric
## nulldev 1 -none- numeric
## npasses 1 -none- numeric
## jerr 1 -none- numeric
## offset 1 -none- logical
## call 5 -none- call
## nobs 1 -none- numeric
# in the sequence we went from lambda high to lambda low
# so the parameters of the first models should be smaller:
coef(ridge)[, 1] # more heavily shrunk
## (Intercept) (Intercept) carat cut.L cut.Q
## -9358.7185818 0.0000000 3121.6705399 428.6346699 -153.3709702
## cut.C cut^4 color.L color.Q color.C
## -0.1020196 0.6200210 -921.9242316 -284.9004911 -69.7256961
## color^4 color^5 color^6 clarity.L clarity.Q
## 15.5184250 -47.4734964 -57.1629559 2249.9796100 -729.2600017
## clarity.C clarity^4 clarity^5 clarity^6 clarity^7
## -19.8564166 52.8007675 -29.9372590 35.9108281 114.8650951
## depth table x y z
## -3.4973177 -6.3627280 706.7594218 725.0183722 841.3276792
## trainTRUE
## 0.0000000
coef(ridge)[, 40] # less heavily shrunk
## (Intercept) (Intercept) carat cut.L cut.Q cut.C
## 4573.414644 0.000000 10869.477144 599.517551 -292.118569 103.388704
## cut^4 color.L color.Q color.C color^4 color^5
## -31.579700 -1933.089636 -665.030232 -171.691705 9.240034 -85.324238
## color^6 clarity.L clarity.Q clarity.C clarity^4 clarity^5
## -55.730843 4112.651184 -1927.661068 992.445995 -387.185503 225.347400
## clarity^6 clarity^7 depth table x y
## 6.110287 96.873016 -55.511042 -25.238539 -1419.531759 587.309083
## z trainTRUE
## -65.532631 0.000000
# revisit the plain OLS fit for comparison with the shrunk coefficients above
summary(lm1)
##
## Call:
## lm(formula = price ~ carat + cut, data = dia_train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -17389.7 -788.4 -38.9 511.1 12711.9
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -2678.31 21.72 -123.297 < 2e-16 ***
## carat 7839.95 19.84 395.196 < 2e-16 ***
## cut.L 1246.49 36.44 34.207 < 2e-16 ***
## cut.Q -557.04 32.34 -17.225 < 2e-16 ***
## cut.C 365.52 28.41 12.864 < 2e-16 ***
## cut^4 79.67 22.92 3.476 0.00051 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1508 on 26999 degrees of freedom
## Multiple R-squared: 0.8543, Adjusted R-squared: 0.8543
## F-statistic: 3.167e+04 on 5 and 26999 DF, p-value: < 2.2e-16
# the glmnet package can do cross validation on its own and pick best lambda
cv_ridge <- cv.glmnet(x2, y, alpha = 0, lambda = lambda_grid)
plot(cv_ridge)
ridge_star <- cv_ridge$lambda.min
# suggests 0.001 is best lambda (also the smallest lambda in our grid)
ridge_hat_star <- predict(ridge , s = ridge_star, model.matrix(price ~ ., dia_test))
mse_table$ridge_0.001 <- mean (( ridge_hat_star - dia_test$price)^2)
# however, this is the best lambda picked using only the training dataset
# does not mean it will perform best in the test.
# plus, it's dig-deeper-worthy that the smallest is the best lambda
# we can change the value of lambda by hand within the predict function
# s represents lambda. when s is specified the predict will give a vector
ridge_hat1 <- predict(ridge , s = 1, model.matrix(price ~ ., dia_test))
ridge_hat100 <- predict(ridge , s = 100, model.matrix(price ~ ., dia_test))
mean( (ridge_hat1 - dia_test$price)^2 )
## [1] 1298093
# see lambda = 1 performed better than lambda = 0.001 that cv suggested
mean( (ridge_hat100 - dia_test$price)^2 )
## [1] 1446700
mse_table$ridge_1 <- mean( (ridge_hat1 - dia_test$price)^2 )
mse_table$ridge_100 <- mean( (ridge_hat100 - dia_test$price)^2 )
# when s is null, it will give a matrix (each column for a different lambda)
ridge_hat_all <- predict(ridge , s = NULL, model.matrix(price ~ ., dia_test))
head(ridge_hat_all)
## s0 s1 s2 s3 s4 s5 s6
## 1 -1963.1550 -2053.6863 -2117.1977 -2154.8000 -2169.1583 -2162.2544 -2137.8246
## 2 -1722.8065 -1778.2402 -1808.3148 -1814.0126 -1797.9644 -1761.7860 -1709.1989
## 3 -901.3231 -941.6041 -961.4892 -961.0721 -942.2261 -905.7931 -854.8382
## 4 -2221.4396 -2445.5926 -2639.6673 -2804.6815 -2943.1659 -3057.4164 -3150.8351
## 5 -1528.8689 -1618.1994 -1682.7371 -1723.7059 -1743.6107 -1744.2762 -1729.0198
## 6 -1577.8931 -1650.5645 -1700.0384 -1727.4441 -1735.0679 -1724.9148 -1700.2573
## s7 s8 s9 s10 s11 s12 s13
## 1 -2098.8453 -2048.6617 -1990.9427 -1929.3347 -1866.447 -1805.4950 -1748.7768
## 2 -1642.9983 -1566.5231 -1483.6282 -1398.3432 -1313.278 -1232.1784 -1157.6720
## 3 -791.5077 -718.5689 -639.4267 -557.6252 -475.945 -397.7218 -325.4761
## 4 -3226.1715 -3286.1836 -3333.5526 -3370.6839 -3399.506 -3421.9207 -3439.4395
## 5 -1700.4673 -1661.5659 -1615.5464 -1565.6560 -1514.131 -1463.8726 -1416.9440
## 6 -1663.9733 -1619.2059 -1569.2072 -1517.3185 -1465.081 -1415.4060 -1370.0780
## s14 s15 s16 s17 s18 s19
## 1 -1695.6042 -1647.7884 -1606.0135 -1570.2100 -1539.65914 -1513.00478
## 2 -1088.6012 -1027.0366 -973.7140 -928.4324 -890.19119 -857.25607
## 3 -258.3487 -198.3281 -146.1429 -101.6472 -63.92877 -31.34733
## 4 -3452.5984 -3462.6550 -3470.4463 -3476.5303 -3481.22729 -3484.63602
## 5 -1372.7454 -1332.9121 -1298.0934 -1268.2678 -1242.83972 -1220.66180
## 6 -1328.1434 -1290.9394 -1258.9233 -1231.9239 -1209.26562 -1189.84356
## s20 s21 s22 s23 s24 s25
## 1 -1490.149949 -1470.78066 -1454.72344 -1441.62076 -1430.66693 -1421.55202
## 2 -829.433525 -806.29879 -787.58171 -772.75831 -760.84322 -751.40559
## 3 -3.712143 19.39832 38.25075 53.34083 65.63038 75.53441
## 4 -3487.136683 -3488.99096 -3490.43910 -3491.62010 -3492.52681 -3493.23890
## 5 -1201.679111 -1185.64781 -1172.43655 -1161.74200 -1152.88449 -1145.60132
## 6 -1173.543014 -1160.11498 -1149.39967 -1141.06484 -1134.51643 -1129.49011
## s26 s27 s28 s29 s30 s31
## 1 -1414.10196 -1407.96423 -1402.91255 -1398.70990 -1395.3327 -1392.5313
## 2 -744.09379 -738.39800 -733.96793 -730.49618 -727.8437 -725.7528
## 3 83.36116 89.58679 94.53521 98.50498 101.6030 104.0967
## 4 -3493.82877 -3494.30535 -3494.68968 -3494.99052 -3495.2445 -3495.4408
## 5 -1139.72635 -1134.94808 -1131.06380 -1127.87137 -1125.3334 -1123.2476
## 6 -1125.73998 -1122.93844 -1120.85798 -1119.31291 -1118.1924 -1117.3568
## s32 s33 s34 s35 s36 s37 s38
## 1 -1390.2892 -1388.4396 -1386.9371 -1385.8067 -1384.7873 -1384.0324 -1383.4376
## 2 -724.1504 -722.8829 -721.8939 -721.1663 -720.5297 -720.0717 -719.7154
## 3 106.0452 107.6144 108.8615 109.7909 110.6118 111.2127 111.6831
## 4 -3495.6050 -3495.7314 -3495.8322 -3495.9260 -3495.9825 -3496.0386 -3496.0846
## 5 -1121.5925 -1120.2367 -1119.1429 -1118.3250 -1117.5882 -1117.0468 -1116.6213
## 6 -1116.7510 -1116.2977 -1115.9652 -1115.7319 -1115.5346 -1115.4032 -1115.3039
## s39 s40 s41 s42 s43 s44 s45
## 1 -1382.9239 -1382.5639 -1382.2761 -1381.9881 -1381.8472 -1381.6333 -1381.4925
## 2 -719.4124 -719.2031 -719.0357 -718.8691 -718.7903 -718.6664 -718.5877
## 3 112.0850 112.3657 112.5901 112.8125 112.9216 113.0859 113.1929
## 4 -3496.1150 -3496.1470 -3496.1721 -3496.1833 -3496.2021 -3496.2073 -3496.2120
## 5 -1116.2537 -1115.9982 -1115.7938 -1115.5875 -1115.4900 -1115.3355 -1115.2356
## 6 -1115.2204 -1115.1666 -1115.1230 -1115.0779 -1115.0585 -1115.0259 -1115.0056
## s46 s47 s48 s49 s50 s51 s52
## 1 -1381.3532 -1381.2156 -1381.0807 -1380.9491 -1380.8211 -1380.6971 -1380.5772
## 2 -718.5101 -718.4349 -718.3631 -718.2954 -718.2321 -718.1736 -718.1197
## 3 113.2978 113.4000 113.4986 113.5929 113.6826 113.7676 113.8480
## 4 -3496.2133 -3496.2104 -3496.2044 -3496.1962 -3496.1864 -3496.1756 -3496.1642
## 5 -1115.1359 -1115.0372 -1114.9405 -1114.8463 -1114.7552 -1114.6673 -1114.5828
## 6 -1114.9860 -1114.9674 -1114.9504 -1114.9357 -1114.9236 -1114.9142 -1114.9077
## s53 s54 s55 s56 s57 s58 s59
## 1 -1380.4613 -1380.3493 -1380.2411 -1380.1365 -1380.0353 -1379.9373 -1379.8422
## 2 -718.0704 -718.0255 -717.9848 -717.9481 -717.9150 -717.8854 -717.8589
## 3 113.9238 113.9952 114.0624 114.1257 114.1854 114.2416 114.2948
## 4 -3496.1525 -3496.1408 -3496.1291 -3496.1177 -3496.1065 -3496.0957 -3496.0852
## 5 -1114.5016 -1114.4236 -1114.3488 -1114.2771 -1114.2081 -1114.1418 -1114.0779
## 6 -1114.9039 -1114.9028 -1114.9043 -1114.9082 -1114.9143 -1114.9225 -1114.9326
## s60
## 1 -1379.7499
## 2 -717.8353
## 3 114.3450
## 4 -3496.0751
## 5 -1114.0163
## 6 -1114.9444
# Out-of-sample MSE per lambda for a glmnet fit.
#
# hat_matrix: test-set predictions, one column per lambda (as produced by
#   predict(model, s = NULL, ...)); compared against the global dia_test$price.
# Returns c(index of the best lambda, the best out-of-sample MSE).
#
# Fixes vs the original: it recomputed predictions from `model` into a local
# `shrink_hats` that was never used (the loop read `hat_matrix`), and it grew
# `mses` with c() inside a loop. `model` is kept in the signature for
# backward compatibility but is intentionally unused now.
bestmse_shrink <- function(model, hat_matrix) {
  # column-wise MSE; dia_test$price recycles down each prediction column
  mses <- colMeans((hat_matrix - dia_test$price)^2)
  c(which.min(mses), min(mses))
}
# element 1 is the index of the best lambda; element 2 is the best
# out-of-sample mse it gives
best_ridge <- bestmse_shrink(ridge, ridge_hat_all)
# so it's the 24th lambda in the list that gives the best out-of-sample mse:
lambda_grid[24]
## [1] 5.011872
mse_table$ridge_5 <- best_ridge[2]
#
# lasso: same design matrix and lambda grid, alpha = 1 for the L1 penalty
lasso <- glmnet(x2, y, alpha = 1, lambda = lambda_grid)
summary(lasso)
## Length Class Mode
## a0 61 -none- numeric
## beta 1525 dgCMatrix S4
## df 61 -none- numeric
## dim 2 -none- numeric
## lambda 61 -none- numeric
## dev.ratio 61 -none- numeric
## nulldev 1 -none- numeric
## npasses 1 -none- numeric
## jerr 1 -none- numeric
## offset 1 -none- logical
## call 5 -none- call
## nobs 1 -none- numeric
coef(lasso)[, 1] # it eliminated all variables except carat
## (Intercept) (Intercept) carat cut.L cut.Q cut.C
## -548.5598 0.0000 5601.3582 0.0000 0.0000 0.0000
## cut^4 color.L color.Q color.C color^4 color^5
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## color^6 clarity.L clarity.Q clarity.C clarity^4 clarity^5
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## clarity^6 clarity^7 depth table x y
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## z trainTRUE
## 0.0000 0.0000
coef(lasso)[, 10] # kept carat and some color, clarity and one cut aspects
## (Intercept) (Intercept) carat cut.L cut.Q cut.C
## -2677.82669 0.00000 8194.94827 148.06335 0.00000 0.00000
## cut^4 color.L color.Q color.C color^4 color^5
## 0.00000 -1032.62945 -32.59695 0.00000 0.00000 0.00000
## color^6 clarity.L clarity.Q clarity.C clarity^4 clarity^5
## 0.00000 2601.03959 -685.06392 0.00000 0.00000 0.00000
## clarity^6 clarity^7 depth table x y
## 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
## z trainTRUE
## 0.00000 0.00000
coef(lasso)[, 40] # eliminated none but still shrunk
## (Intercept) (Intercept) carat cut.L cut.Q cut.C
## 4601.497433 0.000000 10869.007600 600.725554 -294.378963 107.382448
## cut^4 color.L color.Q color.C color^4 color^5
## -28.527528 -1932.412433 -664.244657 -170.883619 9.281562 -85.184584
## color^6 clarity.L clarity.Q clarity.C clarity^4 clarity^5
## -55.618185 4113.154963 -1926.805899 991.688260 -386.565878 224.745655
## clarity^6 clarity^7 depth table x y
## 5.853690 96.305213 -55.885453 -25.298601 -1318.067353 484.915506
## z trainTRUE
## -64.325110 0.000000
# use cv to find the best lambda for the lasso
cv_lasso <- cv.glmnet(x2, y, alpha = 1, lambda = lambda_grid)
plot(cv_lasso)
# bug fix: the original read cv_ridge$lambda.min here (the ridge CV object);
# the lasso's own cross-validated lambda is cv_lasso$lambda.min. The column
# name lasso_0.001 below suggests both CVs happened to pick 0.001, so the
# recorded numbers are likely unaffected -- re-run to confirm.
lasso_star <- cv_lasso$lambda.min
lasso_hat_star <- predict(lasso, s = lasso_star, model.matrix(price ~ ., dia_test))
mse_table$lasso_0.001 <- mean((lasso_hat_star - dia_test$price)^2)
# check other lambdas for out of sample performance
lasso_hat1 <- predict(lasso , s = 1, model.matrix(price ~ ., dia_test))
lasso_hat100 <- predict(lasso , s = 100, model.matrix(price ~ ., dia_test))
mse_table$lasso_1 <- mean (( lasso_hat1 - dia_test$price)^2)
mse_table$lasso_100 <- mean (( lasso_hat100 - dia_test$price)^2)
# let's check them all: s = NULL returns one prediction column per lambda
lasso_hat_all <- predict(lasso , s = NULL, model.matrix(price ~ ., dia_test))
best_lasso <- bestmse_shrink(lasso, lasso_hat_all)
best_lasso
## [1] 31 1272705
lambda_grid[31]
## [1] 1
# 31st lambda which is one and we already saved it into mse table
# mse_table$lasso_1 <- best_lasso[2]
# transpose so each model becomes a row; the single value column is named V1
mse_all <- as.data.frame(t(mse_table))
min(mse_all$V1)
## [1] 1272705
kable(mse_all)
| V1 | |
|---|---|
| lm1 | 2295292 |
| lm2 | 2285418 |
| lm3 | 1337119 |
| lm4 | 1402004 |
| ridge_0.001 | 1323918 |
| ridge_1 | 1298093 |
| ridge_100 | 1446700 |
| ridge_5 | 1281099 |
| lasso_0.001 | 1316549 |
| lasso_1 | 1272705 |
| lasso_100 | 1504542 |
coef(lasso)[, 31] # let's check our champion coefficients (lasso, lambda = 1)
## (Intercept) (Intercept) carat cut.L cut.Q cut.C
## 4578.747552 0.000000 10802.133061 604.230651 -301.421697 122.545747
## cut^4 color.L color.Q color.C color^4 color^5
## -12.704798 -1924.422334 -657.383713 -165.748616 8.984044 -83.355879
## color^6 clarity.L clarity.Q clarity.C clarity^4 clarity^5
## -54.466311 4108.124757 -1913.183733 980.190071 -379.827149 218.440085
## clarity^6 clarity^7 depth table x y
## 3.934344 92.900303 -56.953337 -25.249497 -834.788004 16.013504
## z trainTRUE
## -47.211912 0.000000
# check scale of the variables maybe can normalize to make
# the shrinkage more reliable
# Standardize a numeric vector to mean 0 and standard deviation 1 (z-score).
normalize <- function(x) {
  (x - mean(x)) / sd(x)
}
# normalize(diamonds$carat)
# names(diamonds)
# add z-scored copies of the numeric columns
# NOTE(review): naming is inconsistent -- price gets `price_norm` while the
# others get a `norm_` prefix
dia_norm <- diamonds %>%
mutate(
norm_carat = normalize(carat),
norm_depth = normalize(depth),
norm_table = normalize(table),
price_norm = normalize(price),
norm_x = normalize(x),
norm_y = normalize(y),
norm_z = normalize(z)
)
# if we build these models after having normalized, I think we'd expect
# to shrink depth and table more and carat less.
# you can try and check yourself
# eyeball candidate transformations of carat against price
plot(diamonds$carat, diamonds$price)
plot((diamonds$carat)^2, diamonds$price)
plot(log(diamonds$carat), diamonds$price)
plot(log(diamonds$carat), log(diamonds$price))
# log-log scatter of price vs carat, points coloured by cut quality
scatter_lnprice <- diamonds %>%
  ggplot(aes(x = log(carat), y = log(price), color = cut)) +
  geom_point(alpha = 0.7) +
  xlab("Natural Log of Carat") + ylab("Natural Log of Price") +
  # fix: the mapped aesthetic is `color`, so the guide must target color;
  # the original `guides(fill = ...)` left the legend titled "cut"
  guides(color = guide_legend(title = "Cut"))
scatter_lnprice
# now let's try some models with logs
# NOTE(review): select() drops only carat, so the raw `price` column stays in
# diamonds_log and `.` below makes it a REGRESSOR for ln_price -- target
# leakage; the much lower MSEs further down are therefore not a fair
# comparison with the earlier models. Consider select(-carat, -price).
# The constant `train` column also rides along via `.`, which is the likely
# source of the rank-deficient-fit warnings below -- confirm by re-running.
diamonds_log <- diamonds %>%
mutate(ln_carat = log(carat), ln_price = log(price)) %>%
select(-carat)
dialog_train <- filter(diamonds_log, train)
dialog_test <- filter(diamonds_log, !train)
lm3_log <- lm(ln_price ~ . - x - y -z, data = dialog_train)
lm4_log <- lm(ln_price ~ . , data = dialog_train)
# Out of sample mean squared error
# remember the predictions here are log prices so exponentiate
lm3_loghat <- exp(predict(lm3_log, newdata = dialog_test))
## Warning in predict.lm(lm3_log, newdata = dialog_test): prediction from a rank-
## deficient fit may be misleading
mean( (dia_test$price - lm3_loghat)^2 )
## [1] 768614.7
# the dia_test has the same test prices as dialog_test just not logged
lm4_loghat <- exp(predict(lm4_log, newdata = dialog_test))
## Warning in predict.lm(lm4_log, newdata = dialog_test): prediction from a rank-
## deficient fit may be misleading
mean( (dia_test$price - lm4_loghat)^2 )
## [1] 658183.5
# these are much better mse's than all the other models we had
# (see the leakage caveat above before trusting that comparison)
# you can try running a log log regression for ridge and lasso now
I hope this was helpful, y'all. Send me your R or assignment questions whenever. See you at the next review session.
Peace,
Lutfi